import gc

import numpy as np

from utils_for_llm import *
import json
import os
from random import randrange
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainerCallback, TrainerState, TrainerControl
import argparse
import warnings
import warnings
import os


# Silence every warning for the whole script.
# NOTE(review): this also hides pandas' SettingWithCopyWarning and library
# deprecation warnings; consider narrowing the filter.
warnings.filterwarnings("ignore")


parser = argparse.ArgumentParser(
    description="Build in-context-learning prompts for the dev and OOD dev splits.",
)
parser.add_argument(
    "--debug",
    action="store_true",
    help="only build prompts for the first 10 samples of the dev split",
)
# Other candidate files: ./data/test_data.json, ./data/ALL_data.json,
# ./data/full_synthesized_with_seed_data.json
parser.add_argument(
    "--train_file",
    default='./data/query_and_description_3.json',
    type=str,
    help="JSON file holding the list of samples; the first one is the ICL example",
)
args = parser.parse_args()

# JSON text is UTF-8 by specification; do not depend on the platform's locale
# encoding when decoding it.
with open(args.train_file, 'r', encoding='utf-8') as fp:
    data = json.load(fp)  # assumed to be a list of sample dicts — TODO confirm
if not data:
    raise ValueError(
        f"{args.train_file} contains no samples; the first sample is needed as the ICL example"
    )
# The first sample serves as the in-context example used by build_prompt_ICL.
ICL_example = data[0]


# Shared system instruction; build_prompt_ICL appends an in-context example to it.
_SYSTEM_MESSAGE = (
    "You are a very helpful AI assistant who can write corresponding Python main code "
    "based on user's query and usable Python function interface."
)

# Shared user instruction. Wording and spacing are intentionally left exactly
# as they were so generated prompts stay identical to earlier runs.
_USER_MESSAGE_TEMPLATE = (
    "Please generate python main code based on the following query :\n {query}\n"
    "You can start by using natural language to plan your tool call strategy, "
    "and then generate the complete code. "
    "For example, `Thought:\n<tool call strategy>\n\nCode:\n```python\n<main code>\n````.\n"
    "Note that your output should always include "
    "`Code:\n```python\n<main code>\n````, formatted accordingly.\n"
    "Here are some useful function interface you may use:\n {apis_desc}"
)


def _resolve_action_names(sample):
    """Return the API identifiers available to ``sample``.

    The sample's own ``apis`` entry is used when present. When it is missing,
    ``None``, or a float (e.g. NaN for an empty DataFrame cell), the list is
    taken from ``stat[sample['key']]['action_names']`` instead.
    """
    apis = sample.get('apis')
    if apis is None or isinstance(apis, float):
        return stat[sample['key']]['action_names']
    return apis


def _describe_apis(action_names):
    """Join the Python interface descriptions of ``action_names`` with newlines.

    Each name is looked up in ``identifier2python`` after replacing dots with
    underscores; names without a description are skipped.
    """
    descriptions = (identifier2python.get(name.replace('.', '_')) for name in action_names)
    return "\n".join(desc for desc in descriptions if desc is not None)


def _chat_prompt(system_message, user_message):
    """Wrap the two messages in a chat list that ends with an empty assistant turn."""
    return [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": ""},
    ]


def build_prompt(sample):
    """Build a zero-shot chat prompt asking the model to write main code.

    Args:
        sample: a dict or pandas row providing ``query`` and ``key`` and,
            optionally, ``apis`` (the API identifiers usable for this query).

    Returns:
        A list of three chat messages: system, user, and an empty assistant turn.

    Raises:
        KeyError: if ``apis`` is unusable and ``sample['key']`` is not in ``stat``,
            or if ``query`` is missing.
    """
    action_names = _resolve_action_names(sample)
    query = sample['query']
    apis_desc = _describe_apis(action_names)
    user_message = _USER_MESSAGE_TEMPLATE.format(query=query, apis_desc=apis_desc)
    return _chat_prompt(_SYSTEM_MESSAGE, user_message)


def build_prompt_ICL(sample):
    """Build a one-shot chat prompt that embeds ``ICL_example`` in the system message.

    Args:
        sample: a dict or pandas row providing ``query`` and ``key`` and,
            optionally, ``apis`` (the API identifiers usable for this query).

    Returns:
        A list of three chat messages: system, user, and an empty assistant turn.

    Raises:
        KeyError: if ``apis`` is unusable and ``sample['key']`` is not in ``stat``,
            if ``ICL_example['key']`` is not in ``stat``, or if a required field
            (``query``, ``description``, ``line_by_line``) is missing.
    """
    action_names = _resolve_action_names(sample)
    query = sample['query']
    apis_desc = _describe_apis(action_names)

    # The in-context example always takes its API list from `stat`.
    icl_apis_desc = _describe_apis(stat[ICL_example['key']]['action_names'])

    # Byte-identical to the previous prompt, including the trailing space,
    # newline and four-space indent before "Here is an example."
    system_message = (
        _SYSTEM_MESSAGE
        + " \n"
        + f"    Here is an example. query:{ICL_example['query']}\n"
        + f" Function interface: {icl_apis_desc} Thought:\n{ICL_example['description']}\n\n"
        + f" Code:\n```python\n{ICL_example['line_by_line']}\n```"
    )
    # Unlike build_prompt, this user message starts with a single space.
    user_message = " " + _USER_MESSAGE_TEMPLATE.format(query=query, apis_desc=apis_desc)
    return _chat_prompt(system_message, user_message)

if __name__ == "__main__":
    # Build ICL prompts for the in-distribution dev split and the OOD dev split,
    # then pickle both DataFrames to ./prompt_for_API_ICL.pkl.
    # `pickle` is not imported at the top of this file (it presumably came in
    # through the wildcard utils_for_llm import), so import it explicitly.
    import pickle

    with open('./data/dataset_split_keys.json', 'r', encoding='utf-8') as fp:
        dataset_split = json.load(fp)

    with open('./data/dataset_split_keys_ood.json', 'r', encoding='utf-8') as fp:
        dataset_split_ood = json.load(fp)

    # Keep only samples whose key has an entry in `stat` or that belong to one
    # of the synthesized buckets.
    synthesized_keys = {'synthesized_training_data', 'synthesized_ood_test_data'}
    data = pd.DataFrame([
        sample for sample in data
        if sample['key'] in stat or sample['key'] in synthesized_keys
    ])

    dev_keys = set(dataset_split['dev'])
    ood_keys = set(dataset_split_ood['dev'])
    val_df = data[data['key'].isin(dev_keys)]
    ood_df = data[data['key'].isin(ood_keys)]

    # Debug mode only shrinks the in-distribution dev split.
    if args.debug:
        val_df = val_df.head(10)

    # Work on independent copies so adding the 'prompt' column does not write
    # into a view of `data` (the SettingWithCopyWarning would otherwise be
    # hidden by the global warnings filter).
    val_df = val_df.copy()
    ood_df = ood_df.copy()

    # Build prompts row by row rather than with DataFrame.apply(axis=1): on an
    # empty split, apply returns a DataFrame, which cannot be assigned to a
    # single column; an explicit object Series works for any length.
    for split_df in (val_df, ood_df):
        split_df['prompt'] = pd.Series(
            [build_prompt_ICL(row) for _, row in split_df.iterrows()],
            index=split_df.index,
            dtype=object,
        )

    prompt_for_API_ICL = {
        'val': val_df,
        'ood': ood_df,
    }

    with open('./prompt_for_API_ICL.pkl', 'wb') as fp:
        pickle.dump(prompt_for_API_ICL, fp)

    print('dada')